#!/usr/bin/python
# /// script
# requires-python = ">=3.11"
# dependencies = []
# ///

from __future__ import annotations

import re
from dataclasses import dataclass
from pathlib import Path
from subprocess import check_output
from typing import Any

import tomllib

# Types that require `crate::`. We can slowly remove these types as we move them to generate scripts.
types_requiring_crate_prefix = {
    "IpyEscapeKind",
    "ExprContext",
    "Identifier",
    "Number",
    "BytesLiteralValue",
    "StringLiteralValue",
    "FStringValue",
    "TStringValue",
    "Arguments",
    "CmpOp",
    "Comprehension",
    "DictItem",
    "UnaryOp",
    "BoolOp",
    "Operator",
    "Decorator",
    "TypeParams",
    "Parameters",
    "ElifElseClause",
    "WithItem",
    "MatchCase",
    "Alias",
    "Singleton",
    "PatternArguments",
}


@dataclass
class VisitorInfo:
    name: str
    accepts_sequence: bool = False


# Map of AST node types to their corresponding visitor information.
# Only visitors that are different from the default `visit_*` method are included.
# These visitors either have a different name or accept a sequence of items.
type_to_visitor_function: dict[str, VisitorInfo] = {
    "TypeParams": VisitorInfo("visit_type_params", True),
    "Parameters": VisitorInfo("visit_parameters", True),
    "Stmt": VisitorInfo("visit_body", True),
    "Arguments": VisitorInfo("visit_arguments", True),
}


def rustfmt(code: str) -> str:
    return check_output(["rustfmt", "--emit=stdout"], input=code, text=True)


def to_snake_case(node: str) -> str:
    """Converts CamelCase to snake_case"""
    return re.sub("([A-Z])", r"_\1", node).lower().lstrip("_")


def write_rustdoc(out: list[str], doc: str) -> None:
    for line in doc.split("\n"):
        out.append(f"/// {line}")


# ------------------------------------------------------------------------------
# Read AST description


def load_ast(root: Path) -> Ast:
    ast_path = root.joinpath("crates", "ruff_python_ast", "ast.toml")
    with ast_path.open("rb") as ast_file:
        ast = tomllib.load(ast_file)
    return Ast(ast)


# ------------------------------------------------------------------------------
# Preprocess


@dataclass
class Ast:
    """
    The parsed representation of the `ast.toml` file. Defines all of the Python
    AST syntax nodes, and which groups (`Stmt`, `Expr`, etc.) they belong to.
    """

    groups: list[Group]
    ungrouped_nodes: list[Node]
    all_nodes: list[Node]

    def __init__(self, ast: dict[str, Any]) -> None:
        self.groups = []
        self.ungrouped_nodes = []
        self.all_nodes = []
        for group_name, group in ast.items():
            group = Group(group_name, group)
            self.all_nodes.extend(group.nodes)
            if group_name == "ungrouped":
                self.ungrouped_nodes = group.nodes
            else:
                self.groups.append(group)


@dataclass
class Group:
    name: str
    nodes: list[Node]
    owned_enum_ty: str

    add_suffix_to_is_methods: bool
    anynode_is_label: str
    doc: str | None

    def __init__(self, group_name: str, group: dict[str, Any]) -> None:
        self.name = group_name
        self.owned_enum_ty = group_name
        self.ref_enum_ty = group_name + "Ref"
        self.add_suffix_to_is_methods = group.get("add_suffix_to_is_methods", False)
        self.anynode_is_label = group.get("anynode_is_label", to_snake_case(group_name))
        self.doc = group.get("doc")
        self.nodes = [
            Node(self, node_name, node) for node_name, node in group["nodes"].items()
        ]


@dataclass
class Node:
    name: str
    variant: str
    ty: str
    doc: str | None
    fields: list[Field] | None
    derives: list[str]
    custom_source_order: bool
    source_order: list[str] | None

    def __init__(self, group: Group, node_name: str, node: dict[str, Any]) -> None:
        self.name = node_name
        self.variant = node.get("variant", node_name.removeprefix(group.name))
        self.ty = f"crate::{node_name}"
        self.fields = None
        fields = node.get("fields")
        if fields is not None:
            self.fields = [Field(f) for f in fields]
        self.custom_source_order = node.get("custom_source_order", False)
        self.derives = node.get("derives", [])
        self.doc = node.get("doc")
        self.source_order = node.get("source_order")

    def fields_in_source_order(self) -> list[Field]:
        if self.fields is None:
            return []
        if self.source_order is None:
            return list(filter(lambda x: not x.skip_source_order(), self.fields))

        fields = []
        for field_name in self.source_order:
            field = None
            for field in self.fields:
                if field.skip_source_order():
                    continue
                if field.name == field_name:
                    field = field
                    break
            fields.append(field)
        return fields


@dataclass
class Field:
    name: str
    ty: str
    _skip_visit: bool
    is_annotation: bool
    parsed_ty: FieldType

    def __init__(self, field: dict[str, Any]) -> None:
        self.name = field["name"]
        self.ty = field["type"]
        self.parsed_ty = FieldType(self.ty)
        self._skip_visit = field.get("skip_visit", False)
        self.is_annotation = field.get("is_annotation", False)

    def skip_source_order(self) -> bool:
        return self._skip_visit or self.parsed_ty.inner in [
            "str",
            "ExprContext",
            "Name",
            "u32",
            "bool",
            "Number",
            "IpyEscapeKind",
        ]


# Extracts the type argument from the given rust type with AST field type syntax.
# Box<str> -> str
# Box<Expr?> -> Expr
# If the type does not have a type argument, it will return the string.
# Does not support nested types
def extract_type_argument(rust_type_str: str) -> str:
    rust_type_str = rust_type_str.replace("*", "")
    rust_type_str = rust_type_str.replace("?", "")
    rust_type_str = rust_type_str.replace("&", "")

    open_bracket_index = rust_type_str.find("<")
    if open_bracket_index == -1:
        return rust_type_str
    close_bracket_index = rust_type_str.rfind(">")
    if close_bracket_index == -1 or close_bracket_index <= open_bracket_index:
        raise ValueError(f"Brackets are not balanced for type {rust_type_str}")
    inner_type = rust_type_str[open_bracket_index + 1 : close_bracket_index].strip()
    inner_type = inner_type.replace("crate::", "")
    return inner_type


@dataclass
class FieldType:
    rule: str
    name: str
    inner: str
    seq: bool = False
    optional: bool = False
    slice_: bool = False

    def __init__(self, rule: str) -> None:
        self.rule = rule
        self.name = ""
        self.inner = extract_type_argument(rule)

        # The following cases are the limitations of this parser(and not used in the ast.toml):
        # * Rules that involve declaring a sequence with optional items e.g. Vec<Option<...>>
        last_pos = len(rule) - 1
        for i, ch in enumerate(rule):
            if ch == "?":
                if i == last_pos:
                    self.optional = True
                else:
                    raise ValueError(f"`?` must be at the end: {rule}")
            elif ch == "*":
                if self.slice_:  # The * after & is a slice
                    continue
                if i == last_pos:
                    self.seq = True
                else:
                    raise ValueError(f"`*` must be at the end: {rule}")
            elif ch == "&":
                if i == 0 and rule.endswith("*"):
                    self.slice_ = True
                else:
                    raise ValueError(
                        f"`&` must be at the start and end with `*`: {rule}"
                    )
            else:
                self.name += ch

        if self.optional and (self.seq or self.slice_):
            raise ValueError(f"optional field cannot be sequence or slice: {rule}")


# ------------------------------------------------------------------------------
# Preamble


def write_preamble(out: list[str]) -> None:
    out.append("""
    // This is a generated file. Don't modify it by hand!
    // Run `crates/ruff_python_ast/generate.py` to re-generate the file.

    use crate::name::Name;
    use crate::visitor::source_order::SourceOrderVisitor;
    """)


# ------------------------------------------------------------------------------
# Owned enum


def write_owned_enum(out: list[str], ast: Ast) -> None:
    """
    Create an enum for each group that contains an owned copy of a syntax node.

    ```rust
    pub enum TypeParam {
        TypeVar(TypeParamTypeVar),
        TypeVarTuple(TypeParamTypeVarTuple),
        ...
    }
    ```

    Also creates:
    - `impl Ranged for TypeParam`
    - `impl HasNodeIndex for TypeParam`
    - `TypeParam::visit_source_order`
    - `impl From<TypeParamTypeVar> for TypeParam`
    - `impl Ranged for TypeParamTypeVar`
    - `impl HasNodeIndex for TypeParamTypeVar`
    - `fn TypeParam::is_type_var() -> bool`

    If the `add_suffix_to_is_methods` group option is true, then the
    `is_type_var` method will be named `is_type_var_type_param`.
    """

    for group in ast.groups:
        out.append("")
        if group.doc is not None:
            write_rustdoc(out, group.doc)
        out.append("#[derive(Clone, Debug, PartialEq)]")
        out.append('#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]')
        out.append(f"pub enum {group.owned_enum_ty} {{")
        for node in group.nodes:
            out.append(f"{node.variant}({node.ty}),")
        out.append("}")

        for node in group.nodes:
            out.append(f"""
            impl From<{node.ty}> for {group.owned_enum_ty} {{
                fn from(node: {node.ty}) -> Self {{
                    Self::{node.variant}(node)
                }}
            }}
            """)

        out.append(f"""
        impl ruff_text_size::Ranged for {group.owned_enum_ty} {{
            fn range(&self) -> ruff_text_size::TextRange {{
                match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.range(),")
        out.append("""
                }
            }
        }
        """)

        out.append(f"""
        impl crate::HasNodeIndex for {group.owned_enum_ty} {{
            fn node_index(&self) -> &crate::AtomicNodeIndex {{
                match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.node_index(),")
        out.append("""
                }
            }
        }
        """)

        out.append(
            "#[allow(dead_code, clippy::match_wildcard_for_single_variants)]"
        )  # Not all is_methods are used
        out.append(f"impl {group.name} {{")
        for node in group.nodes:
            is_name = to_snake_case(node.variant)
            variant_name = node.variant
            match_arm = f"Self::{variant_name}"
            if group.add_suffix_to_is_methods:
                is_name = to_snake_case(node.variant + group.name)
            if len(group.nodes) > 1:
                out.append(f"""
                    #[inline]
                    pub const fn is_{is_name}(&self) -> bool {{
                        matches!(self, {match_arm}(_))
                    }}

                    #[inline]
                    pub fn {is_name}(self) -> Option<{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                            _ => None,
                        }}
                    }}

                    #[inline]
                    pub fn expect_{is_name}(self) -> {node.ty} {{
                        match self {{
                            {match_arm}(val) => val,
                            _ => panic!("called expect on {{self:?}}"),
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}_mut(&mut self) -> Option<&mut {node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                            _ => None,
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}(&self) -> Option<&{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                            _ => None,
                        }}
                    }}
                           """)
            elif len(group.nodes) == 1:
                out.append(f"""
                    #[inline]
                    pub const fn is_{is_name}(&self) -> bool {{
                        matches!(self, {match_arm}(_))
                    }}

                    #[inline]
                    pub fn {is_name}(self) -> Option<{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                        }}
                    }}

                    #[inline]
                    pub fn expect_{is_name}(self) -> {node.ty} {{
                        match self {{
                            {match_arm}(val) => val,
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}_mut(&mut self) -> Option<&mut {node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}(&self) -> Option<&{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                        }}
                    }}
                           """)

        out.append("}")

    for node in ast.all_nodes:
        out.append(f"""
            impl ruff_text_size::Ranged for {node.ty} {{
                fn range(&self) -> ruff_text_size::TextRange {{
                    self.range
                }}
            }}
        """)

    for node in ast.all_nodes:
        out.append(f"""
            impl crate::HasNodeIndex for {node.ty} {{
                fn node_index(&self) -> &crate::AtomicNodeIndex {{
                    &self.node_index
                }}
            }}
        """)

    for group in ast.groups:
        out.append(f"""
            impl {group.owned_enum_ty} {{
                #[allow(unused)]
                pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
                where
                    V: crate::visitor::source_order::SourceOrderVisitor<'a> + ?Sized,
                {{
                    match self {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.owned_enum_ty}::{node.variant}(node) => node.visit_source_order(visitor),"
            )
        out.append("""
                    }
                }
            }
        """)


# ------------------------------------------------------------------------------
# Ref enum


def write_ref_enum(out: list[str], ast: Ast) -> None:
    """
    Create an enum for each group that contains a reference to a syntax node.

    ```rust
    pub enum TypeParamRef<'a> {
        TypeVar(&'a TypeParamTypeVar),
        TypeVarTuple(&'a TypeParamTypeVarTuple),
        ...
    }
    ```

    Also creates:
    - `impl<'a> From<&'a TypeParam> for TypeParamRef<'a>`
    - `impl<'a> From<&'a TypeParamTypeVar> for TypeParamRef<'a>`
    - `impl Ranged for TypeParamRef<'_>`
    - `impl HasNodeIndex for TypeParamRef<'_>`
    - `fn TypeParamRef::is_type_var() -> bool`

    The name of each variant can be customized via the `variant` node option. If
    the `add_suffix_to_is_methods` group option is true, then the `is_type_var`
    method will be named `is_type_var_type_param`.
    """

    for group in ast.groups:
        out.append("")
        if group.doc is not None:
            write_rustdoc(out, group.doc)
        out.append("""#[derive(Clone, Copy, Debug, PartialEq, is_macro::Is)]""")
        out.append('#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]')
        out.append(f"""pub enum {group.ref_enum_ty}<'a> {{""")
        for node in group.nodes:
            if group.add_suffix_to_is_methods:
                is_name = to_snake_case(node.variant + group.name)
                out.append(f'#[is(name = "{is_name}")]')
            out.append(f"""{node.variant}(&'a {node.ty}),""")
        out.append("}")

        out.append(f"""
            impl<'a> From<&'a {group.owned_enum_ty}> for {group.ref_enum_ty}<'a> {{
                fn from(node: &'a {group.owned_enum_ty}) -> Self {{
                    match node {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.owned_enum_ty}::{node.variant}(node) => {group.ref_enum_ty}::{node.variant}(node),"
            )
        out.append("""
                    }
                }
            }
        """)

        for node in group.nodes:
            out.append(f"""
            impl<'a> From<&'a {node.ty}> for {group.ref_enum_ty}<'a> {{
                fn from(node: &'a {node.ty}) -> Self {{
                    Self::{node.variant}(node)
                }}
            }}
            """)

        out.append(f"""
        impl ruff_text_size::Ranged for {group.ref_enum_ty}<'_> {{
            fn range(&self) -> ruff_text_size::TextRange {{
                match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.range(),")
        out.append("""
                }
            }
        }
        """)

        out.append(f"""
        impl crate::HasNodeIndex for {group.ref_enum_ty}<'_> {{
            fn node_index(&self) -> &crate::AtomicNodeIndex {{
                match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.node_index(),")
        out.append("""
                }
            }
        }
        """)


# ------------------------------------------------------------------------------
# AnyNodeRef


def write_anynoderef(out: list[str], ast: Ast) -> None:
    """
    Create the AnyNodeRef type.

    ```rust
    pub enum AnyNodeRef<'a> {
        ...
        TypeParamTypeVar(&'a TypeParamTypeVar),
        TypeParamTypeVarTuple(&'a TypeParamTypeVarTuple),
        ...
    }
    ```

    Also creates:
    - `impl<'a> From<&'a TypeParam> for AnyNodeRef<'a>`
    - `impl<'a> From<TypeParamRef<'a>> for AnyNodeRef<'a>`
    - `impl<'a> From<&'a TypeParamTypeVarTuple> for AnyNodeRef<'a>`
    - `impl Ranged for AnyNodeRef<'_>`
    - `impl HasNodeIndex for AnyNodeRef<'_>`
    - `fn AnyNodeRef::as_ptr(&self) -> std::ptr::NonNull<()>`
    - `fn AnyNodeRef::visit_source_order(self, visitor &mut impl SourceOrderVisitor)`
    """

    out.append("""
    /// A flattened enumeration of all AST nodes.
    #[derive(Copy, Clone, Debug, is_macro::Is, PartialEq)]
    #[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]
    pub enum AnyNodeRef<'a> {
    """)
    for node in ast.all_nodes:
        out.append(f"""{node.name}(&'a {node.ty}),""")
    out.append("""
    }
    """)

    for group in ast.groups:
        out.append(f"""
            impl<'a> From<&'a {group.owned_enum_ty}> for AnyNodeRef<'a> {{
                fn from(node: &'a {group.owned_enum_ty}) -> AnyNodeRef<'a> {{
                    match node {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.owned_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
            )
        out.append("""
                    }
                }
            }
        """)

        out.append(f"""
            impl<'a> From<{group.ref_enum_ty}<'a>> for AnyNodeRef<'a> {{
                fn from(node: {group.ref_enum_ty}<'a>) -> AnyNodeRef<'a> {{
                    match node {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.ref_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
            )
        out.append("""
                    }
                }
            }
        """)

        # `as_*` methods to convert from `AnyNodeRef` to e.g. `ExprRef`
        out.append(f"""
            impl<'a> AnyNodeRef<'a> {{
                pub fn as_{to_snake_case(group.ref_enum_ty)}(self) -> Option<{group.ref_enum_ty}<'a>> {{
                    match self {{
        """)
        for node in group.nodes:
            out.append(
                f"Self::{node.name}(node) => Some({group.ref_enum_ty}::{node.variant}(node)),"
            )
        out.append("""
                        _ => None,
                    }
                }
            }
        """)

    for node in ast.all_nodes:
        out.append(f"""
            impl<'a> From<&'a {node.ty}> for AnyNodeRef<'a> {{
                fn from(node: &'a {node.ty}) -> AnyNodeRef<'a> {{
                    AnyNodeRef::{node.name}(node)
                }}
            }}
        """)

    out.append("""
        impl ruff_text_size::Ranged for AnyNodeRef<'_> {
            fn range(&self) -> ruff_text_size::TextRange {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(f"""AnyNodeRef::{node.name}(node) => node.range(),""")
    out.append("""
                }
            }
        }
    """)

    out.append("""
        impl crate::HasNodeIndex for AnyNodeRef<'_> {
            fn node_index(&self) -> &crate::AtomicNodeIndex {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(f"""AnyNodeRef::{node.name}(node) => node.node_index(),""")
    out.append("""
                }
            }
        }
    """)

    out.append("""
        impl AnyNodeRef<'_> {
            pub fn as_ptr(&self) -> std::ptr::NonNull<()> {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(
            f"AnyNodeRef::{node.name}(node) => std::ptr::NonNull::from(*node).cast(),"
        )
    out.append("""
                }
            }
        }
    """)

    out.append("""
        impl<'a> AnyNodeRef<'a> {
            pub fn visit_source_order<'b, V>(self, visitor: &mut V)
            where
                V: crate::visitor::source_order::SourceOrderVisitor<'b> + ?Sized,
                'a: 'b,
            {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(
            f"AnyNodeRef::{node.name}(node) => node.visit_source_order(visitor),"
        )
    out.append("""
                }
            }
        }
    """)

    for group in ast.groups:
        out.append(f"""
        impl AnyNodeRef<'_> {{
            pub const fn is_{group.anynode_is_label}(self) -> bool {{
                matches!(self,
        """)
        for i, node in enumerate(group.nodes):
            if i > 0:
                out.append("|")
            out.append(f"""AnyNodeRef::{node.name}(_)""")
        out.append("""
                )
            }
        }
        """)


# ------------------------------------------------------------------------------
# AnyRootNodeRef


def write_root_anynoderef(out: list[str], ast: Ast) -> None:
    """
    Create the AnyRootNodeRef type.

    ```rust
    pub enum AnyRootNodeRef<'a> {
        ...
        TypeParam(&'a TypeParam),
        ...
    }
    ```

    Also creates:
    - `impl<'a> From<&'a TypeParam> for AnyRootNodeRef<'a>`
    - `impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a TypeParam`
    - `impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a TypeParamVarTuple`
    - `impl Ranged for AnyRootNodeRef<'_>`
    - `impl HasNodeIndex for AnyRootNodeRef<'_>`
    - `fn AnyRootNodeRef::visit_source_order(self, visitor &mut impl SourceOrderVisitor)`
    """

    out.append("""
    /// An enumeration of all AST nodes.
    ///
    /// Unlike `AnyNodeRef`, this type does not flatten nested enums, so its variants only
    /// consist of the "root" AST node types. This is useful as it exposes references to the
    /// original enums, not just references to their inner values.
    ///
    /// For example, `AnyRootNodeRef::Mod` contains a reference to the `Mod` enum, while
    /// `AnyNodeRef` has top-level `AnyNodeRef::ModModule` and `AnyNodeRef::ModExpression`
    /// variants.
    #[derive(Copy, Clone, Debug, PartialEq)]
    #[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]
    pub enum AnyRootNodeRef<'a> {
    """)
    for group in ast.groups:
        out.append(f"""{group.name}(&'a {group.owned_enum_ty}),""")
    for node in ast.ungrouped_nodes:
        out.append(f"""{node.name}(&'a {node.ty}),""")
    out.append("""
    }
    """)

    for group in ast.groups:
        out.append(f"""
            impl<'a> From<&'a {group.owned_enum_ty}> for AnyRootNodeRef<'a> {{
                fn from(node: &'a {group.owned_enum_ty}) -> AnyRootNodeRef<'a> {{
                        AnyRootNodeRef::{group.name}(node)
                }}
            }}
        """)

        out.append(f"""
            impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a {group.owned_enum_ty} {{
                type Error = ();
                fn try_from(node: AnyRootNodeRef<'a>) -> Result<&'a {group.owned_enum_ty}, ()> {{
                    match node {{
                        AnyRootNodeRef::{group.name}(node) => Ok(node),
                        _ => Err(())
                    }}
                }}
            }}
        """)

        for node in group.nodes:
            out.append(f"""
                impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a {node.ty} {{
                    type Error = ();
                    fn try_from(node: AnyRootNodeRef<'a>) -> Result<&'a {node.ty}, ()> {{
                        match node {{
                            AnyRootNodeRef::{group.name}({group.owned_enum_ty}::{node.variant}(node)) => Ok(node),
                            _ => Err(())
                        }}
                    }}
                }}
            """)

    for node in ast.ungrouped_nodes:
        out.append(f"""
            impl<'a> From<&'a {node.ty}> for AnyRootNodeRef<'a> {{
                fn from(node: &'a {node.ty}) -> AnyRootNodeRef<'a> {{
                    AnyRootNodeRef::{node.name}(node)
                }}
            }}
        """)

        out.append(f"""
            impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a {node.ty} {{
                type Error = ();
                fn try_from(node: AnyRootNodeRef<'a>) -> Result<&'a {node.ty}, ()> {{
                    match node {{
                        AnyRootNodeRef::{node.name}(node) => Ok(node),
                        _ => Err(())
                    }}
                }}
            }}
        """)

    out.append("""
        impl ruff_text_size::Ranged for AnyRootNodeRef<'_> {
            fn range(&self) -> ruff_text_size::TextRange {
                match self {
    """)
    for group in ast.groups:
        out.append(f"""AnyRootNodeRef::{group.name}(node) => node.range(),""")
    for node in ast.ungrouped_nodes:
        out.append(f"""AnyRootNodeRef::{node.name}(node) => node.range(),""")
    out.append("""
                }
            }
        }
    """)

    out.append("""
        impl crate::HasNodeIndex for AnyRootNodeRef<'_> {
            fn node_index(&self) -> &crate::AtomicNodeIndex {
                match self {
    """)
    for group in ast.groups:
        out.append(f"""AnyRootNodeRef::{group.name}(node) => node.node_index(),""")
    for node in ast.ungrouped_nodes:
        out.append(f"""AnyRootNodeRef::{node.name}(node) => node.node_index(),""")
    out.append("""
                }
            }
        }
    """)

    out.append("""
        impl<'a> AnyRootNodeRef<'a> {
            pub fn visit_source_order<'b, V>(self, visitor: &mut V)
            where
                V: crate::visitor::source_order::SourceOrderVisitor<'b> + ?Sized,
                'a: 'b,
            {
                match self {
    """)
    for group in ast.groups:
        out.append(
            f"""AnyRootNodeRef::{group.name}(node) => node.visit_source_order(visitor),"""
        )
    for node in ast.ungrouped_nodes:
        out.append(
            f"""AnyRootNodeRef::{node.name}(node) => node.visit_source_order(visitor),"""
        )
    out.append("""
                }
            }
        }
    """)


# ------------------------------------------------------------------------------
# NodeKind


def write_nodekind(out: list[str], ast: Ast) -> None:
    """
    Create the NodeKind type.

    ```rust
    pub enum NodeKind {
        ...
        TypeParamTypeVar,
        TypeParamTypeVarTuple,
        ...
    }

    Also creates:
    - `fn AnyNodeRef::kind(self) -> NodeKind`
    ```
    """

    out.append("""
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
    pub enum NodeKind {
    """)
    for node in ast.all_nodes:
        out.append(f"""{node.name},""")
    out.append("""
    }
    """)

    out.append("""
    impl AnyNodeRef<'_> {
        pub const fn kind(self) -> NodeKind {
            match self {
    """)
    for node in ast.all_nodes:
        out.append(f"""AnyNodeRef::{node.name}(_) => NodeKind::{node.name},""")
    out.append("""
            }
        }
    }
    """)


# ------------------------------------------------------------------------------
# Node structs


def write_node(out: list[str], ast: Ast) -> None:
    group_names = [group.name for group in ast.groups]
    for group in ast.groups:
        for node in group.nodes:
            if node.fields is None:
                continue
            if node.doc is not None:
                write_rustdoc(out, node.doc)
            out.append(
                "#[derive(Clone, Debug, PartialEq"
                + "".join(f", {derive}" for derive in node.derives)
                + ")]"
            )
            out.append('#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]')
            name = node.name
            out.append(f"pub struct {name} {{")
            out.append("pub node_index: crate::AtomicNodeIndex,")
            out.append("pub range: ruff_text_size::TextRange,")
            for field in node.fields:
                field_str = f"pub {field.name}: "
                ty = field.parsed_ty

                rust_ty = f"{field.parsed_ty.name}"
                if ty.name in types_requiring_crate_prefix:
                    rust_ty = f"crate::{rust_ty}"
                if ty.slice_:
                    rust_ty = f"[{rust_ty}]"
                if (ty.name in group_names or ty.slice_) and ty.seq is False:
                    rust_ty = f"Box<{rust_ty}>"

                if ty.seq:
                    rust_ty = f"Vec<{rust_ty}>"
                elif ty.optional:
                    rust_ty = f"Option<{rust_ty}>"

                field_str += rust_ty + ","
                out.append(field_str)
            out.append("}")
            out.append("")


# ------------------------------------------------------------------------------
# Source order visitor


def write_source_order(out: list[str], ast: Ast) -> None:
    for group in ast.groups:
        for node in group.nodes:
            if node.fields is None or node.custom_source_order:
                continue
            name = node.name
            fields_list = ""
            body = ""

            for field in node.fields:
                if field.skip_source_order():
                    fields_list += f"{field.name}: _,\n"
                else:
                    fields_list += f"{field.name},\n"
            fields_list += "range: _,\n"
            fields_list += "node_index: _,\n"

            for field in node.fields_in_source_order():
                visitor_name = (
                    type_to_visitor_function.get(
                        field.parsed_ty.inner, VisitorInfo("")
                    ).name
                    or f"visit_{to_snake_case(field.parsed_ty.inner)}"
                )
                visits_sequence = type_to_visitor_function.get(
                    field.parsed_ty.inner, VisitorInfo("")
                ).accepts_sequence

                if field.is_annotation:
                    visitor_name = "visit_annotation"

                if field.parsed_ty.optional:
                    body += f"""
                            if let Some({field.name}) = {field.name} {{
                                visitor.{visitor_name}({field.name});
                            }}\n
                      """
                elif not visits_sequence and field.parsed_ty.seq:
                    body += f"""
                            for elm in {field.name} {{
                                visitor.{visitor_name}(elm);
                            }}
                     """
                else:
                    body += f"visitor.{visitor_name}({field.name});\n"

            visitor_arg_name = "visitor"
            if len(node.fields_in_source_order()) == 0:
                visitor_arg_name = "_"

            out.append(f"""
impl {name} {{
    pub(crate) fn visit_source_order<'a, V>(&'a self, {visitor_arg_name}: &mut V)
    where
        V: SourceOrderVisitor<'a> + ?Sized,
    {{
        let {name} {{
            {fields_list}
        }} = self;
        {body}
    }}
}}
        """)


# ------------------------------------------------------------------------------
# Format and write output


def generate(ast: Ast) -> list[str]:
    out = []
    write_preamble(out)
    write_owned_enum(out, ast)
    write_ref_enum(out, ast)
    write_anynoderef(out, ast)
    write_root_anynoderef(out, ast)
    write_nodekind(out, ast)
    write_node(out, ast)
    write_source_order(out, ast)
    return out


def write_output(root: Path, out: list[str]) -> None:
    out_path = root.joinpath("crates", "ruff_python_ast", "src", "generated.rs")
    out_path.write_text(rustfmt("\n".join(out)))


# ------------------------------------------------------------------------------
# Main


def main() -> None:
    root = Path(
        check_output(["git", "rev-parse", "--show-toplevel"], text=True).strip()
    )
    ast = load_ast(root)
    out = generate(ast)
    write_output(root, out)


if __name__ == "__main__":
    main()
